// Copyright 2015 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "sparrow/base64.h"

#include <stddef.h>
#include <algorithm>
#include <string>

#include "src/base64/modp_b64.h"

extern "C" {

// Padding character that may trail base64 / base64url encoded content.
static constexpr char kPaddingChar = '=';

// The conventional base64 alphabet contains '+' and '/', which are not safe
// to embed in a URL. Base64url substitutes '-' and '_' for them, position for
// position; the code in this file translates between the two sets.
static constexpr char kBase64Chars[] = "+/";
static constexpr char kBase64UrlSafeChars[] = "-_";

// Overwrites, in place, every element of |input->data| that equals |old_char|
// with |new_char|. Only indices in [0, input->length) are visited.
// |input| is dereferenced unconditionally, so it must not be null.
static void ReplaceChars(struct FlexibleArray *input, char old_char, char new_char) {
    for (int pos = 0; pos < input->length; ++pos) {
        auto &current = input->data[pos];
        if (current != old_char) {
            continue;
        }
        current = new_char;
    }
}

// Replaces, in place, every occurrence of |old_char| in |input| with
// |new_char|. The length of |input| is never changed.
//
// The whole string is walked through its iterators, so strings whose size does
// not fit in an int are handled correctly (no narrowing of the size).
static void ReplaceChars(std::string &input, char old_char, char new_char) {
    std::replace(input.begin(), input.end(), old_char, new_char);
}

// Encodes |input_size| bytes starting at |input| as base64url.
//
// The data is first encoded with Base64Encode(), then every '+' is rewritten
// to '-' and every '/' to '_' so the result is safe to use in a URL.
//
// |policy| controls trailing padding:
//   INCLUDE_PADDING - the encoded output is left as produced.
//   OMIT_PADDING    - trailing '=' characters are dropped by shrinking
//                     |(*output)->length|; the buffer itself is not modified
//                     beyond the character substitution.
//
// |output| receives the encoded array via Base64Encode(). If |output| is null,
// or |*output| is null after encoding, the function returns without doing
// anything further.
// NOTE(review): ownership of |*output| and whether Base64Encode() can fail are
// not visible from this file - confirm against its declaration.
void Base64UrlEncode(const void *input, int input_size, Base64UrlEncodePolicy policy,
                     struct FlexibleArray **output) {
    if (output == NULL) return;

    Base64Encode(input, input_size, output);

    // Nothing to post-process if no output array is available.
    if (*output == NULL) return;

    ReplaceChars(*output, '+', '-');
    ReplaceChars(*output, '/', '_');

    switch (policy) {
    case Base64UrlEncodePolicy::INCLUDE_PADDING:
        // The padding included in |*output| will not be amended.
        break;
    case Base64UrlEncodePolicy::OMIT_PADDING: {
        // Walk back over trailing padding; |trimmed_length| ends up one past
        // the last non-padding character, or 0 if there is none.
        int trimmed_length = (*output)->length;
        while (trimmed_length > 0 && (*output)->data[trimmed_length - 1] == kPaddingChar) {
            --trimmed_length;
        }
        // An empty or all-padding array is deliberately left untouched.
        if (trimmed_length > 0) {
            (*output)->length = trimmed_length;
        }
        break;
    }
    }
}

// Decodes the base64url content in the |input_size| bytes at |input_ptr|.
//
// Input containing '+' or '/' (the conventional base64 characters that
// base64url replaces) is rejected. '-' and '_' are translated back to '+' and
// '/' before the data is handed to Base64Decode().
//
// |policy| controls how padding is treated:
//   REQUIRE_PADDING  - fail unless |input_size| is a multiple of four.
//   IGNORE_PADDING   - missing padding is silently appended.
//   DISALLOW_PADDING - fail if the input contains any '=' character; the
//                      padding needed for decoding is appended internally.
//
// |*output| is reset to NULL up front and is only populated by Base64Decode().
//
// Returns -1 if |output| is null, |input_size| is negative, |input_ptr| is
// null while |input_size| is positive, or the input violates the rules above.
// Otherwise returns whatever Base64Decode() returns.
int Base64UrlDecode(const void *input_ptr, int input_size, Base64UrlDecodePolicy policy,
                    struct FlexibleArray **output) {
    if (output == NULL) return -1;
    *output = NULL;

    // A negative size would wrap around to an enormous length once converted
    // to an unsigned type, and a null buffer cannot be scanned.
    if (input_size < 0 || (input_ptr == NULL && input_size > 0)) return -1;

    // The input is inspected in place; a copy is only made further down when
    // the data actually has to be modified.
    const char *const first = static_cast<const char *>(input_ptr);
    const char *const last = first + input_size;

    // Characters outside of the base64url alphabet are disallowed, which includes
    // the {+, /} characters found in the conventional base64 alphabet.
    // (sizeof - 1 excludes the terminating NUL of the character set.)
    if (std::find_first_of(first, last, kBase64Chars,
                           kBase64Chars + sizeof(kBase64Chars) - 1) != last) {
        return -1;
    }

    // Number of characters beyond the last complete group of four; non-zero
    // means padding has to be added before decoding.
    const std::size_t trailing_chars = static_cast<std::size_t>(input_size) % 4;
    const bool needs_replacement =
        std::find_first_of(first, last, kBase64UrlSafeChars,
                           kBase64UrlSafeChars + sizeof(kBase64UrlSafeChars) - 1) != last;

    switch (policy) {
    case Base64UrlDecodePolicy::REQUIRE_PADDING:
        // Fail if the required padding is not included in the input.
        if (trailing_chars > 0) return -1;
        break;
    case Base64UrlDecodePolicy::IGNORE_PADDING:
        // Missing padding will be silently appended.
        break;
    case Base64UrlDecodePolicy::DISALLOW_PADDING:
        // Fail if padding characters are included in the input.
        if (std::find(first, last, kPaddingChar) != last) return -1;
        break;
    }

    // If the input either needs replacement of URL-safe characters to normal
    // base64 ones, or additional padding, a copy needs to be made in order to
    // make these adjustments without side effects on the caller's buffer.
    if (trailing_chars > 0 || needs_replacement) {
        std::size_t padded_size = static_cast<std::size_t>(input_size);
        if (trailing_chars > 0) padded_size += 4 - trailing_chars;

        std::string base64_input;
        base64_input.reserve(padded_size);
        base64_input.append(first, static_cast<std::size_t>(input_size));

        // Substitute the base64url URL-safe characters to their base64 equivalents.
        ReplaceChars(base64_input, '-', '+');
        ReplaceChars(base64_input, '_', '/');

        // Append the necessary padding characters.
        base64_input.resize(padded_size, kPaddingChar);

        return Base64Decode(base64_input.data(), base64_input.size(), output);
    }

    return Base64Decode(input_ptr, input_size, output);
}

} // extern "C"
